# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Generator, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union

from megatron.energon.flavors.base_dataset import (
    FlexState,
    Sample,
    SavableDataset,
    add_sample_restore_key,
)
from megatron.energon.savable import Savable
from megatron.energon.worker import WorkerConfig

# General-purpose type variable.
T = TypeVar("T")
# Element type of the iterable handed to `SampleIndex.iter_ctx`.
T_sample = TypeVar("T_sample", covariant=True)
# Sample type produced by a `BaseWrapperDataset`.
T_sample_out = TypeVar("T_sample_out", covariant=True)
# Sample type of the datasets wrapped by a `BaseWrapperDataset`.
T_sample_in = TypeVar("T_sample_in", covariant=True)


class BaseWrapperDataset(SavableDataset[T_sample_out], Generic[T_sample_in, T_sample_out], ABC):
    """Common base for every dataset wrapper.

    A wrapper holds one or more inner datasets and exposes a new dataset built on top of them,
    e.g. by shuffling samples, applying a custom function to each sample, or changing the length
    or repetition of the data. State saving/restoring and sample restoring are forwarded to the
    inner datasets by the methods defined here.
    """

    datasets: Tuple[SavableDataset[T_sample_in], ...]

    def __init__(
        self,
        datasets: Union[SavableDataset[T_sample_in], Iterable[SavableDataset[T_sample_in]]],
        *,
        worker_config: WorkerConfig,
    ):
        """Store the wrapped dataset(s) and verify their worker configuration.

        Args:
            datasets: Either a single dataset or an iterable of datasets to wrap. It is always
                stored as a tuple in `self.datasets`.
            worker_config: Worker configuration of the wrapper; every wrapped dataset must use
                an equal configuration.
        """
        super().__init__(worker_config=worker_config)

        # A lone dataset is normalized to a 1-tuple, so the rest of the class only handles tuples.
        self.datasets = (datasets,) if isinstance(datasets, SavableDataset) else tuple(datasets)

        # Wrapper and wrapped datasets must agree on the worker configuration.
        for inner in self.datasets:
            assert inner.worker_config == self.worker_config, (
                "Dataset and wrapper worker configs must match."
            )

    @property
    def dataset(self) -> SavableDataset:
        """The single wrapped dataset. Only valid if exactly one dataset is wrapped."""

        wrapped = self.datasets
        assert len(wrapped) == 1
        return wrapped[0]

    def can_restore_sample(self) -> bool:
        """Return True only if every wrapped dataset reports that it can restore samples."""
        for inner in self.datasets:
            if not inner.can_restore_sample():
                return False
        return True

    def assert_can_restore(self) -> None:
        """Forward `assert_can_restore` to every wrapped dataset, in order."""
        for inner in self.datasets:
            inner.assert_can_restore()

    def worker_has_samples(self) -> bool:
        """Return True as soon as one wrapped dataset reports samples for this worker."""
        for inner in self.datasets:
            if inner.worker_has_samples():
                return True
        return False

    def _find_wrapped_dataset(self, cls: Type[SavableDataset]) -> Optional[SavableDataset]:
        """Search the wrapped datasets for an instance of `cls`.

        The wrapped datasets are visited in order. A dataset that is itself an instance of `cls`
        is returned directly; otherwise, if it is a wrapper, it is searched recursively before
        moving on to the next dataset. Returns None if nothing matches.
        """

        for inner in self.datasets:
            if isinstance(inner, cls):
                return inner
            if isinstance(inner, BaseWrapperDataset):
                found = inner._find_wrapped_dataset(cls)
                if found is not None:
                    return found
        return None

    def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample_out:
        """Restore a sample from its restore key by delegating to a wrapped dataset.

        With exactly one wrapped dataset, `index` is forwarded unchanged. With several, the first
        two entries of `index` must be this class's name and the integer position of the inner
        dataset; the remaining entries are forwarded to that dataset, and the restored sample is
        passed through `add_sample_restore_key` together with that position.
        """
        if len(self.datasets) == 1:
            return self.datasets[0].restore_sample(index)

        # Multi-dataset key layout: (<wrapper class name>, <inner dataset position>, *inner key)
        wrapper_name, ds_idx = index[:2]
        assert wrapper_name == type(self).__name__
        inner_key = index[2:]
        assert isinstance(ds_idx, int)
        restored = self.datasets[ds_idx].restore_sample(inner_key)
        return add_sample_restore_key(restored, ds_idx, src=self)

    def save_state(self) -> FlexState:
        """Return the wrapper's own state plus a `datasets` list of the inner datasets' states."""
        own_state = super().save_state()
        inner_states = [inner.save_state() for inner in self.datasets]

        return FlexState(datasets=inner_states, **own_state)

    def restore_state(self, state: FlexState) -> None:
        """Restore the inner datasets from `state["datasets"]`, then the wrapper's own state.

        The number of saved inner states must match the number of wrapped datasets.
        """
        inner_states = state["datasets"]
        assert len(self.datasets) == len(inner_states)
        for inner, inner_state in zip(self.datasets, inner_states):
            inner.restore_state(inner_state)

        super().restore_state(state)

    def reset_state_deep(self) -> None:
        """Reset all wrapped datasets first, then this wrapper's own state.

        Wrapped wrappers are reset recursively; any other dataset only has `reset_state_own`
        called on it.
        """

        for inner in self.datasets:
            reset = (
                inner.reset_state_deep
                if isinstance(inner, BaseWrapperDataset)
                else inner.reset_state_own
            )
            reset()

        self.reset_state_own()

    @abstractmethod
    def reset_state_own(self) -> None:
        """Reset only this dataset's own state, leaving the wrapped datasets untouched."""
        ...


class SampleIndex(Savable):
    """Counter that hands out consecutive sample indexes for one worker.

    The index that is currently being processed is published through the active worker
    config (`WorkerConfig.active_worker_config`) for the duration of a `ctx()` scope.
    """

    worker_config: WorkerConfig
    current_idx: int

    # Number of `ctx()` scopes that are currently open. This is a class attribute, so the
    # count is shared by all instances.
    actives = 0

    def __init__(self, worker_config: WorkerConfig, *, src: Any) -> None:
        """Start counting at 0.

        Args:
            worker_config: The worker configuration, stored on the instance.
            src: The object owning this index, stored on the instance.
        """
        self.src = src
        self.current_idx = 0
        self.worker_config = worker_config

    def get_next(self) -> int:
        """Return the current index and advance the counter by one."""
        idx = self.current_idx
        self.current_idx = idx + 1
        return idx

    @contextmanager
    def ctx(self, sample_idx: Optional[int] = None):
        """Context manager that makes a sample index active on the active worker config.

        Args:
            sample_idx: The index to activate. If None, the next index is taken from the counter.

        Yields:
            The activated sample index.

        On exit, the index is popped again from the active worker config and must be equal to
        the one pushed on entry.
        """
        if sample_idx is None:
            sample_idx = self.get_next()
        active_config = WorkerConfig.active_worker_config
        assert active_config is not None
        active_config.worker_push_sample_index(sample_idx)
        SampleIndex.actives += 1
        try:
            yield sample_idx
        finally:
            # Look up the active config again instead of reusing the one from entry.
            active_config = WorkerConfig.active_worker_config
            assert active_config is not None
            popped = active_config.worker_pop_sample_index()
            SampleIndex.actives -= 1
            assert popped == sample_idx, f"Expected {sample_idx}, got {popped}"

    def iter_ctx(
        self,
        it: Iterable[T_sample],
        sample_idx: Optional[int] = None,
    ) -> Generator[Tuple[int, T_sample], None, None]:
        """Iterate over `it`, fetching each element inside a `ctx()` scope.

        Only the fetch of the next element runs inside the scope; the scope is already closed
        when the `(sample index, element)` pair is yielded.

        Args:
            it: The iterable to consume.
            sample_idx: Fixed index to activate for every fetch. If None, a new index is taken
                from the counter for every fetch, including the last one which finds the
                iterator exhausted.

        Yields:
            Tuples of the sample index that was active during the fetch and the fetched element.
        """
        source = iter(it)
        try:
            while True:
                try:
                    with self.ctx(sample_idx) as active_idx:
                        item = next(source)
                    yield active_idx, item
                except StopIteration:
                    # The source is exhausted: finish the generator.
                    return
        finally:
            # Close the source iterator (e.g. a generator), if it supports that.
            if hasattr(source, "close"):
                source.close()

    def save_state(self) -> int:
        """Return the current counter value."""
        return self.current_idx

    def restore_state(self, state: Optional[int]) -> None:
        """Set the counter to `state`, or back to 0 if `state` is None."""
        self.current_idx = 0 if state is None else state


def get_sample_restore_key(sample: Any) -> Optional[Union[str, int]]:
    """Extract the restore key of a sample, whatever the sample's type is.

    Args:
        sample: A `Sample`, any object with a `__restore_key__` attribute, a dict with a
            `"__restore_key__"` entry, or anything else.

    Returns:
        The `__restore_key__` attribute if present (checked first), else the dict entry of the
        same name, else None.
    """
    # Attribute-style key: `Sample` instances and any object exposing the attribute.
    if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"):
        return sample.__restore_key__

    # Mapping-style key: plain dict samples.
    if isinstance(sample, dict) and "__restore_key__" in sample:
        return sample["__restore_key__"]

    return None
